//https://leetcode.cn/problems/cherry-pickup-ii/description/?envType=daily-question&envId=2024-05-07

class Solution {
public:
    /**
     * Cherry Pickup II: two robots start in the top row, robot 1 in column 0
     * and robot 2 in the last column. On every step each robot moves down one
     * row and shifts its column by -1, 0 or +1 (staying inside the grid).
     * A cell's value is collected once, even when both robots stand on it.
     *
     * State: best[y0][y1] = maximum total collectable from the current row
     * down to the last row when robot 1 is in column y0 and robot 2 in
     * column y1. Rows are processed bottom-up, so only the layer for the row
     * directly below is needed; two flat m*m buffers are swapped per row
     * instead of keeping a full (n+1) x m x m table.
     *
     * The same recurrence can be written as a top-down recursion
     * f(x, y0, y1) or as a full 3D table; this rolling-layer form is the
     * space-optimised variant of those.
     *
     * Time:  O(n * m^2 * 9).  Extra space: O(m^2).
     *
     * @param grid  cherry counts per cell; assumed rectangular (every row as
     *              wide as the first row) -- this is not verified here.
     * @return the maximum total both robots can collect, or 0 when the grid
     *         has no rows or no columns.
     */
    int cherryPickup(std::vector<std::vector<int>>& grid) {
        // Guard degenerate input: without this, grid[0] / column (m-1) would
        // be accessed out of bounds.
        if (grid.empty() || grid[0].empty()) {
            return 0;
        }

        const int rows = static_cast<int>(grid.size());
        const int cols = static_cast<int>(grid[0].size());

        // Flattened [y0][y1] layers, indexed as y0 * cols + y1.
        // `below` starts as all zeros: it stands for the virtual row past the
        // bottom of the grid, where nothing more can be collected.
        std::vector<int> below(cols * cols, 0);
        std::vector<int> current(cols * cols, 0);

        for (int x = rows - 1; x >= 0; --x) {
            const std::vector<int>& row = grid[x];

            for (int y0 = 0; y0 < cols; ++y0) {
                for (int y1 = 0; y1 < cols; ++y1) {
                    // Both robots on the same cell collect it only once.
                    const int score = (y0 == y1) ? row[y0] : row[y0] + row[y1];

                    int best = 0;
                    for (int dy0 = -1; dy0 <= 1; ++dy0) {
                        const int ny0 = y0 + dy0;
                        if (ny0 < 0 || ny0 >= cols) {
                            continue;
                        }
                        for (int dy1 = -1; dy1 <= 1; ++dy1) {
                            const int ny1 = y1 + dy1;
                            if (ny1 < 0 || ny1 >= cols) {
                                continue;
                            }
                            best = std::max(best, below[ny0 * cols + ny1] + score);
                        }
                    }
                    current[y0 * cols + y1] = best;
                }
            }

            // The layer just computed becomes "the row below" for row x-1.
            // Every entry of the other buffer is overwritten next iteration,
            // so it does not need clearing.
            current.swap(below);
        }

        // After the final swap `below` holds the layer for row 0.
        // Start state: robot 1 at column 0, robot 2 at column cols-1.
        return below[cols - 1];
    }
};